package encrypt

import (
	"bytes"
	"crypto/aes"
	"crypto/cipher"
	"encoding/base64"
	"encoding/hex"
	"errors"
	"fmt"
)

type (
	// AESHandler AES CBC 加密模式
	AESHandler struct {
		CipherHandler
	}

	// CipherHandler AES密钥
	CipherHandler struct {
		Secret string
		Iv     string
		Output bool // false默认十六进制，true base64
	}
)

// Cipher AES密钥
func (h *AESHandler) Cipher(handler CipherHandler) *AESHandler {
	h.CipherHandler = handler
	return h
}

// Encrypt 使用AES-CBC加密数据，根据密钥长度选择AES-128, AES-192, 或 AES-256
func (h *AESHandler) Encrypt(plaintext []byte) (ciphertext []byte, err error) {
	block, err := aes.NewCipher([]byte(h.Secret))
	if err != nil {
		return nil, fmt.Errorf(`instance AES failed %w`, err)
	}

	// 添加填充值
	plaintext, err = h.PKCS7Padding(plaintext, block.BlockSize())
	if err != nil {
		return nil, err
	}

	// 初始化
	ciphertext = make([]byte, len(plaintext))

	// 实例加密
	cipher.NewCBCEncrypter(block, h.GetIv()).CryptBlocks(ciphertext, plaintext)

	return
}

// EncryptString 使用AES-CBC加密数据，根据密钥长度选择AES-128, AES-192, 或 AES-256, 转换为十六进制或base64
func (h *AESHandler) EncryptString(plaintext []byte) (string, error) {
	rawCipher, err := h.Encrypt(plaintext)
	if err != nil {
		return "", err
	}
	var ciphertext string
	// 输出格式
	switch h.Output {
	case true:
		// 转换Base64
		ciphertext = base64.StdEncoding.EncodeToString(rawCipher)
	default:
		// 转换为十六进制
		ciphertext = hex.EncodeToString(rawCipher)
	}
	return ciphertext, nil
}

// MustEncryptString 必须加密字符串 使用AES-CBC加密数据，根据密钥长度选择AES-128, AES-192, 或 AES-256, 转换为十六进制或base64
func (h *AESHandler) MustEncryptString(plaintext []byte) string {
	ciphertext, err := h.EncryptString(plaintext)
	if err != nil {
		panic(err)
	}
	return ciphertext
}

// Decrypt 使用AES-CBC解密数据，根据密钥长度选择AES-128, AES-192, 或 AES-256
func (h *AESHandler) Decrypt(ciphertext []byte) (plaintext []byte, err error) {
	if len(ciphertext) == 0 {
		return nil, errors.New(`ciphertext empty`)
	}

	// 密文格式错误
	if len(ciphertext) < aes.BlockSize {
		return nil, errors.New(`ciphertext too short or format error`)
	}

	// 实例AES模块
	block, err := aes.NewCipher([]byte(h.Secret))
	if err != nil {
		return nil, err
	}

	// 初始化
	plaintext = make([]byte, len(ciphertext))

	// 使用CBC解密
	cipher.NewCBCDecrypter(block, h.GetIv()).CryptBlocks(plaintext, ciphertext)

	// 去除填充数
	return h.PKCS7UnPadding(plaintext)
}

// DecryptString 解密原始密文并逆解析转义密文
func (h *AESHandler) DecryptString(rawCipher string) ([]byte, error) {
	var (
		ciphertext []byte
		err        error
	)

	// 按输出类型逆解析为Base格式
	switch h.Output {
	case true:
		// 解析Base64格式
		ciphertext, err = base64.StdEncoding.DecodeString(rawCipher)
	default:
		// 解析十六进制
		ciphertext, err = hex.DecodeString(rawCipher)
	}
	if err != nil {
		return nil, fmt.Errorf(`parse AES raw ciphertext failed %w`, err)
	}

	// 密文解密
	return h.Decrypt(ciphertext)
}

// GetIv 获取初始化IV向量
func (h *AESHandler) GetIv() []byte {
	iv := []byte(h.Iv) // 初始化指定的IV
	if len(iv) == 0 {
		// 不指定IV，初始化向量IV应该是不可预测的随机字节
		iv = []byte(h.Secret)[:aes.BlockSize]
	}
	return iv
}

// PKCS7Padding 实现了PKCS#7的填充方式
func (h *AESHandler) PKCS7Padding(src []byte, blockSize int) ([]byte, error) {
	// 根据密钥长度自动填充CBC-128, CBC-192, CBC-256三种方式
	if len(src) == 0 {
		return nil, errors.New(`encrypt plaintext empty`)
	}
	// 填充干扰值
	padding := blockSize - len(src)%blockSize
	if padding == 0 {
		// 已对齐
		return append(src, bytes.Repeat([]byte{byte(blockSize)}, blockSize)...), nil
	}
	// 未对齐填充补齐
	return append(src, bytes.Repeat([]byte{byte(padding)}, padding)...), nil
}

// PKCS7UnPadding 实现了PKCS#7的去填充方式
func (h *AESHandler) PKCS7UnPadding(src []byte) ([]byte, error) {
	length := len(src)
	if length == 0 {
		return nil, errors.New(`invalid ciphertext empty`)
	}

	// 填充长度
	unPadding := int(src[length-1])
	if unPadding > length {
		return nil, errors.New(`invalid padding number`)
	}

	// 取出密文
	return src[:(length - unPadding)], nil
}
